{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github"
      },
      "source": [
        "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/text-to-video-synthesis-colab/blob/main/watermark_remover_gradio.ipynb)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "VSNbyJrXkQXb"
      },
      "outputs": [],
      "source": [
        "# Install pinned dependencies; the CUDA 11.6 PyTorch wheels come from the extra index.\n",
        "%pip install -q torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 torchtext==0.14.1 torchdata==0.5.1 --extra-index-url https://download.pytorch.org/whl/cu116 -U\n",
        "%pip install -q gradio==3.26.0 pytorch-lightning==1.8.3.post0 hydra-core==1.1.0 webdataset kornia==0.5.0\n",
        "\n",
        "# Clone only when the checkout is missing so the cell can be re-run without errors.\n",
        "![ -d /content/lama ] || git clone -b dev https://github.com/camenduru/lama /content/lama\n",
        "%cd /content/lama\n",
        "\n",
        "# Model checkpoint, config and test mask (-p keeps mkdir quiet when the folders already exist).\n",
        "!mkdir -p /content/lama/big-lama/models /content/lama/masks\n",
        "!wget https://huggingface.co/camenduru/big-lama/resolve/main/big-lama/models/best.ckpt -O /content/lama/big-lama/models/best.ckpt\n",
        "!wget https://huggingface.co/camenduru/big-lama/raw/main/big-lama/config.yaml -O /content/lama/big-lama/config.yaml\n",
        "!wget https://huggingface.co/camenduru/big-lama/resolve/main/big-lama/mask.png -O /content/lama/masks/test_mask.png\n",
        "\n",
        "# Input/output folders handed to predict.py, plus the image_mask.png file downloaded into the input folder.\n",
        "!mkdir -p /content/lama/indir /content/lama/outdir\n",
        "!wget https://huggingface.co/camenduru/big-lama/resolve/main/big-lama/image_mask.png -O /content/lama/indir/image_mask.png\n",
        "\n",
        "# Adapted from https://huggingface.co/spaces/fffiloni/lama-video-watermark-remover/blob/main/app.py\n",
        "\n",
        "import os\n",
        "\n",
        "# Star import kept for VideoFileClip / ImageSequenceClip; it comes first so the\n",
        "# explicit imports below win on any name clash.\n",
        "from moviepy.editor import *\n",
        "\n",
        "import cv2\n",
        "import gradio as gr\n",
        "import imageio  # imported explicitly instead of relying on the star import above\n",
        "from PIL import Image, ImageOps\n",
        "\n",
        "def get_frames(video_in):\n",
        "    \"\"\"Downscale a video to 256 px height and dump every frame to a JPEG file.\n",
        "\n",
        "    The clip is re-encoded to ``video_resized.mp4`` (frame rate capped at 30 fps)\n",
        "    and then split into ``in<N>.jpg`` files in the current working directory.\n",
        "\n",
        "    Args:\n",
        "        video_in: Path of the source video file.\n",
        "\n",
        "    Returns:\n",
        "        A ``(frames, fps)`` tuple: the frame file names in read order and the\n",
        "        frame rate OpenCV reports for the resized video.\n",
        "    \"\"\"\n",
        "    clip = VideoFileClip(video_in)\n",
        "    try:\n",
        "        if clip.fps > 30:\n",
        "            print(\"video rate is over 30, resetting to 30\")\n",
        "            target_fps = 30\n",
        "        else:\n",
        "            print(\"video rate is OK\")\n",
        "            target_fps = clip.fps\n",
        "        clip.resize(height=256).write_videofile(\"video_resized.mp4\", fps=target_fps)\n",
        "    finally:\n",
        "        # Release the reader resources held by the source clip, even if encoding fails.\n",
        "        clip.close()\n",
        "    print(\"video resized to 256 height\")\n",
        "\n",
        "    frames = []\n",
        "    cap = cv2.VideoCapture(\"video_resized.mp4\")\n",
        "    try:\n",
        "        fps = cap.get(cv2.CAP_PROP_FPS)\n",
        "        print(\"video fps: \" + str(fps))\n",
        "        while cap.isOpened():\n",
        "            ret, frame = cap.read()\n",
        "            if not ret:\n",
        "                break\n",
        "            # The running frame count doubles as the zero-based file index.\n",
        "            frame_path = \"in\" + str(len(frames)) + \".jpg\"\n",
        "            cv2.imwrite(frame_path, frame)\n",
        "            frames.append(frame_path)\n",
        "    finally:\n",
        "        cap.release()\n",
        "    print(\"broke the video into frames\")\n",
        "    return frames, fps\n",
        "\n",
        "def create_video(frames, fps, type):\n",
        "    \"\"\"Assemble a frame sequence into an MP4 file and return its file name.\n",
        "\n",
        "    Args:\n",
        "        frames: Frame sequence handed to ``ImageSequenceClip``.\n",
        "        fps: Frame rate used both for the clip and for the written file.\n",
        "        type: Prefix of the output file, which is named ``<type>_result.mp4``.\n",
        "\n",
        "    Returns:\n",
        "        The output file name, ``<type>_result.mp4``.\n",
        "    \"\"\"\n",
        "    print(\"building video result\")\n",
        "    sequence = ImageSequenceClip(frames, fps=fps)\n",
        "    output_path = type + \"_result.mp4\"\n",
        "    sequence.write_videofile(output_path, fps=fps)\n",
        "    return output_path\n",
        "\n",
        "def lama(img):\n",
        "    \"\"\"Run LaMa inpainting on a single frame and return the result path.\n",
        "\n",
        "    The frame is re-saved as ``/content/lama/indir/image.png`` and ``predict.py``\n",
        "    is run over ``/content/lama/indir``; the result is expected at\n",
        "    ``/content/lama/outdir/image_mask.png``.\n",
        "\n",
        "    Args:\n",
        "        img: Path of the frame image to process.\n",
        "\n",
        "    Returns:\n",
        "        The path ``/content/lama/outdir/image_mask.png``.\n",
        "\n",
        "    Raises:\n",
        "        RuntimeError: If ``predict.py`` exits with a non-zero status or the\n",
        "            expected output file is missing afterwards.\n",
        "    \"\"\"\n",
        "    output_path = \"/content/lama/outdir/image_mask.png\"\n",
        "    print(img)\n",
        "    # Close the source file as soon as the frame has been written to the input folder.\n",
        "    with Image.open(img) as frame:\n",
        "        imageio.imwrite(\"/content/lama/indir/image.png\", frame)\n",
        "    # Drop the previous result so a failed run cannot be mistaken for a fresh one.\n",
        "    if os.path.exists(output_path):\n",
        "        os.remove(output_path)\n",
        "    status = os.system('python predict.py model.path=/content/lama/big-lama indir=/content/lama/indir outdir=/content/lama/outdir device=cuda')\n",
        "    if status != 0 or not os.path.exists(output_path):\n",
        "        raise RuntimeError(f\"LaMa inference failed for {img} (exit status {status})\")\n",
        "    return output_path\n",
        "\n",
        "def infer(video_in):\n",
        "    \"\"\"Process every frame of a video with ``lama`` and rebuild the video.\n",
        "\n",
        "    The video is split into frames by ``get_frames``, each frame is passed\n",
        "    through ``lama``, and the results are reassembled by ``create_video`` at\n",
        "    the frame rate ``get_frames`` reported.\n",
        "\n",
        "    Args:\n",
        "        video_in: Path of the uploaded video file.\n",
        "\n",
        "    Returns:\n",
        "        File name of the video written by ``create_video``.\n",
        "\n",
        "    Raises:\n",
        "        gr.Error: If no video was provided or no frame could be read from it.\n",
        "    \"\"\"\n",
        "    if not video_in:\n",
        "        raise gr.Error(\"Please upload a video first.\")\n",
        "    frames_list, fps = get_frames(video_in)\n",
        "    n_frame = len(frames_list)\n",
        "    if n_frame == 0:\n",
        "        raise gr.Error(\"No frames could be read from the input video.\")\n",
        "    print(\"set stop frames to: \" + str(n_frame))\n",
        "    result_frames = []\n",
        "    for index, frame_name in enumerate(frames_list, start=1):\n",
        "        cleaned_path = lama(frame_name)\n",
        "        out_name = f\"out-{frame_name}\"\n",
        "        # Copy the result right away: lama() returns the same path for every frame.\n",
        "        with Image.open(cleaned_path) as cleaned:\n",
        "            imageio.imwrite(out_name, cleaned)\n",
        "        result_frames.append(out_name)\n",
        "        print(\"frame \" + str(index) + \"/\" + str(n_frame) + \": done;\")\n",
        "    return create_video(result_frames, fps, \"cleaned\")\n",
        "\n",
        "# Minimal UI: upload a video in the left column, receive the processed video in the right one.\n",
        "with gr.Blocks() as demo:\n",
        "    with gr.Column(elem_id=\"col-container\"):\n",
        "        with gr.Row():\n",
        "            with gr.Column():\n",
        "                # gr.Video passes the callback a file path; it takes no `type` option.\n",
        "                inputs = gr.Video(label=\"Input\", source=\"upload\")\n",
        "            with gr.Column():\n",
        "                outputs = gr.Video(label=\"output\")\n",
        "                submit_btn = gr.Button(\"Remove Watermark\")\n",
        "        submit_btn.click(infer, inputs=[inputs], outputs=[outputs])\n",
        "# queue() enables request queuing; share=True requests a public link and\n",
        "# inline=False keeps the app from being embedded in the cell output.\n",
        "demo.queue().launch(debug=True, share=True, inline=False)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
